"""
@author: Zhenjiao Jiang <zhenjiao.jiang@csiro.au>, October 2018.

"""
import os
import numpy as np
import flopy.modflow as mf
import flopy.mt3d as mt
import flopy.seawat as swt

import flopy.utils as fu
import flopy.utils.binaryfile as bf

#import time

# ---------------------------------------------------------------------------
# Model domain: a Lx x Ly x (ztop - zbot) block split into nlay*nrow*ncol cells
# (lengths in metres, consistent with the m/d conductivity used below).
# ---------------------------------------------------------------------------
Lx, Ly = 100.0, 100.0             # extent along the column / row direction
ztop, zbot = 100.0, 0.0           # elevation of the model top / bottom
nlay, nrow, ncol = 64, 64, 64
center = nlay // 2                # middle index; reused for row and column too
delr = Lx / ncol                  # column width
delc = Ly / nrow                  # row height
botm = np.linspace(ztop, zbot, nlay + 1)   # layer interfaces, top first

# ---------------------------------------------------------------------------
# Stress periods and time stepping (time unit: days)
# ---------------------------------------------------------------------------
nper = 5
nstp = [20] * nper                # time steps in each stress period
perlen = [1.0] * nper             # length of each stress period
state = [False] * nper            # False -> every period is transient
# Longer schedule used previously:
#   nstp = [60, 40, 60, 40, 60]; perlen = [3.0, 2.0, 3.0, 2.0, 3.0]

# ---------------------------------------------------------------------------
# Groundwater-flow parameters
# ---------------------------------------------------------------------------
Sy = 0.01                         # specific yield (dimensionless)
Ss = 1e-5                         # specific storage (1/m)
laytyp = 1                        # convertible layers (confined <-> unconfined)

# ---------------------------------------------------------------------------
# Thermal properties used to map heat transport onto solute transport
# ---------------------------------------------------------------------------
lamda_solid = 3.59                # rock thermal conductivity, W/(m*degC); 1.73-3.98 for granite
lamda_water = 0.58                # water thermal conductivity, W/(m*degC)
rho_water = 1000.0                # water density, kg/m3
rho_solid = 2750.0                # rock grain density, kg/m3
HC_water = 4186                   # specific heat capacity of water, J/(kg*degC)
HC_solid = 790                    # specific heat capacity of granite, J/(kg*degC)

    
def HF_model(modelname, workspace, Kh, Kv, porosity, rho_bulk, Diff_T, Retard_T):
    """Build, write and run the SEAWAT heat-transport model.

    Temperature is carried as MT3DMS species 1 (initial value 90 everywhere)
    and is coupled back to fluid density (VDF) and viscosity (VSC).  The
    centre cell ``(center, center, center)`` is a constant-head cell whose
    head changes every stress period (2800, 3000, 2850, 3000, 2900) relative
    to the uniform initial head of 3000.

    Parameters
    ----------
    modelname : str
        Base name of all SEAWAT input/output files.
    workspace : str
        Directory the input files are written to and the model is run in.
    Kh, Kv : float or array
        Horizontal / vertical hydraulic conductivity (m/d), passed to LPF.
    porosity : float or array
        Porosity passed to MT3DMS BTN as ``prsity``.
    rho_bulk : float or array
        Bulk density passed to RCT as ``rhob``.
    Diff_T : float
        Thermal "diffusion" coefficient passed to DSP as ``dmcoef`` (m2/d).
    Retard_T : float
        Thermal distribution coefficient passed to RCT as ``sp1`` (linear
        isotherm), representing solid/fluid heat exchange.

    Returns
    -------
    flopy.modflow.ModflowDis
        The discretization package of the model that was run.

    Notes
    -----
    The value returned by ``run_model`` is not checked; a failed run only
    shows up later, when the output files are read.
    """
    # SEAWAT model run with the 'swt_v4' executable inside `workspace`
    swt_model = swt.Seawat(modelname, exe_name='swt_v4', model_ws=workspace)

    # discretization
    dis = mf.ModflowDis(swt_model, nlay, nrow, ncol, delr=delr, delc=delc,
                        top=ztop, botm=botm[1:], nper=nper, steady=state,
                        perlen=perlen, nstp=nstp)

    # hydraulic parameters of every layer
    mf.ModflowLpf(swt_model, hk=Kh, vka=Kv, sy=Sy, ss=Ss, laytyp=laytyp)

    # all cells active, uniform initial head of 3000
    ibound = np.ones((nlay, nrow, ncol), dtype=np.int32)
    strt = np.ones((nlay, nrow, ncol), dtype=np.float32) * 3000.
    mf.ModflowBas(swt_model, ibound=ibound, strt=strt)

    # Constant head at the centre cell, one record per stress period:
    # [lay, row, col, start head, end head].
    well_data = {
        # equilibrium to a suction pressure of 5 MPa
        # NOTE(review): a 200 m drop in water head is ~2 MPa, not 5 MPa - confirm
        0: [center, center, center, 2800, 2800],
        # back to the initial head, i.e. 0 MPa injection over-pressure
        1: [center, center, center, 3000, 3000],
        2: [center, center, center, 2850, 2850],
        3: [center, center, center, 3000, 3000],
        4: [center, center, center, 2900, 2900],
    }
    mf.ModflowChd(swt_model, stress_period_data=well_data)

    # Save heads at every second time step of every stress period.  The number
    # of steps is taken from `nstp` (it used to be hard-coded to 20), so the
    # output control stays valid if the time-stepping schedule is changed.
    stress_period_data = {
        (kper, kstp): ['save head']
        for kper in range(nper)
        for kstp in range(0, nstp[kper], 2)
    }
    mf.ModflowOc(swt_model, stress_period_data=stress_period_data, compact=True)

    mf.ModflowPcg(swt_model)

    # =========================== heat transport ===========================
    # initial temperature distribution
    sconc_1 = np.ones([nlay, nrow, ncol]) * 90.0
    # Temperature print times.  NPRS is the number of TIMPRS entries, so it is
    # derived from the array (it was hard-coded to 40 while `timprs` holds 49
    # times).
    timprs = np.arange(0.1, 5., 0.1)
    mt.Mt3dBtn(swt_model, ncomp=1, sconc=sconc_1, prsity=porosity,
               nprs=len(timprs), timprs=timprs)

    # advection: finite-difference scheme (mixelm=0)
    mt.Mt3dAdv(swt_model, mixelm=0, percel=1., nadvfd=2)

    # dispersion and thermal "diffusion" coefficients
    mt.Mt3dDsp(swt_model, al=50., dmcoef=Diff_T, multiDiff=True)

    # heat exchange between solid and fluid expressed as linear sorption (Kd)
    mt.Mt3dRct(swt_model, isothm=1, igetsc=0, rhob=rho_bulk, sp1=Retard_T,
               sp2=0., rc1=0., rc2=0.)

    mt.Mt3dGcg(swt_model, cclose=1e-8)

    # Temperature attached to the constant-head centre cell per period:
    # [lay, row, col, value, itype].  25 in the periods where the head is
    # returned to 3000; -1 otherwise.
    # NOTE(review): -1 presumably marks periods where the cell acts as a sink
    # and the value is unused - confirm.
    itype = mt.Mt3dSsm.itype_dict()
    ssm_data = {
        0: [center, center, center, -1, itype["CHD"]],
        1: [center, center, center, 25, itype["CHD"]],
        2: [center, center, center, -1, itype["CHD"]],
        3: [center, center, center, 25, itype["CHD"]],
        4: [center, center, center, -1, itype["CHD"]],
    }
    mt.Mt3dSsm(swt_model, stress_period_data=ssm_data)

    # ================== SEAWAT density / viscosity coupling ==================
    swt.SeawatVdf(swt_model, mtdnconc=1, mfnadvfd=1, nswtcpl=1, iwtable=1,
                  densemin=0., densemax=0., dnscrit=0.01, denseref=1.0,
                  denseslp=-3.75e-4, firstdt=0.001)

    swt.SeawatVsc(swt_model, mt3dmuflg=-1, viscmin=0.0, viscmax=0.0,
                  viscref=0.0008904, mutempopt=1, mtmutempspec=1,
                  amucoeff=[2.394e-5, 10., 248.37, 133.15])

    swt_model.write_input()
    swt_model.run_model(silent=True)

    return dis


def extract(workspace, K, Phi):
    """Run the heat-flow model for one (K, Phi) sample and extract results.

    Parameters
    ----------
    workspace : str
        Directory for model input/output.  It is created if missing.
        Outputs left there by a previous run are deleted first.
    K : float or ndarray
        Hydraulic conductivity (m/d), used both horizontally and vertically.
        Either a scalar or an array indexable as ``K[lay, row, col]`` over
        the full ``(nlay, nrow, ncol)`` grid.
    Phi : float or ndarray
        Porosity.

    Returns
    -------
    temperature : ndarray
        ``UcnFile.get_ts`` result at the centre cell (time, temperature).
    Q : ndarray, shape (ntimes, 2)
        Column 0 is time.  Column 1 is the sum over the six faces of the centre
        cell of ``k * (h_neighbour - h_centre) / distance``, i.e. positive
        when flow is towards the centre cell.  ``k`` is the arithmetic mean
        of K over the centre cell and its six neighbours.
    H : ndarray
        ``HeadFile.get_ts`` result at the centre cell (time, head).
    """
    if not os.path.exists(workspace):
        os.makedirs(workspace)
    modelname = "EGS"

    # Delete outputs of a previous run so stale results are never read back.
    # Each file is removed independently: a file that does not exist must not
    # prevent the remaining ones from being removed.
    for fname in ('MT3D001.UCN', modelname + '.hds', modelname + '.cbc'):
        try:
            os.remove(os.path.join(workspace, fname))
        except OSError:
            pass

    Kh = K    # m/d
    Kv = Kh   # m/d (isotropic)
    porosity = Phi
    rho_bulk = rho_solid * (1 - porosity)     # bulk density (solid mass divided by total volume)
    lamda_bulk = lamda_solid * (1 - porosity) + lamda_water * porosity
    # bulk conductivity -> MT3DMS "diffusion" coefficient, converted to m2/d
    Diff_T = lamda_bulk / porosity / rho_water / HC_water * 3600 * 24
    # heat exchange between solid and fluid, expressed as a linear Kd
    Retard_T = HC_solid / HC_water / rho_water

    HF_model(modelname, workspace, Kh, Kv, porosity, rho_bulk, Diff_T, Retard_T)

    idx = (center, center, center)

    conc = fu.UcnFile(os.path.join(workspace, 'MT3D001.UCN'))
    try:
        temperature = conc.get_ts(idx)
    finally:
        conc.close()

    # The six face neighbours of the centre cell with the centre-to-centre
    # distance along each axis.  Layer 0 is the top layer (botm runs from ztop
    # down), so `center - 1` is the layer above.
    z_mid = 0.5 * (botm[:-1] + botm[1:])
    neighbours = [
        ((center - 1, center, center), z_mid[center - 1] - z_mid[center]),  # layer above
        ((center + 1, center, center), z_mid[center] - z_mid[center + 1]),  # layer below
        ((center, center + 1, center), delc),                               # next row
        ((center, center - 1, center), delc),                               # previous row
        ((center, center, center - 1), delr),                               # previous column
        ((center, center, center + 1), delr),                               # next column
    ]

    # Arithmetic mean of K over the centre cell and its six neighbours;
    # broadcasting lets K be a scalar as well as a full 3-D array.
    k_field = np.broadcast_to(Kh, (nlay, nrow, ncol))
    kh = (sum(k_field[cell] for cell, _ in neighbours) + k_field[idx]) / 7

    hds = bf.HeadFile(os.path.join(workspace, modelname + '.hds'))
    try:
        H = hds.get_ts(idx)
        neighbour_heads = [hds.get_ts(cell)[:, 1] for cell, _ in neighbours]
    finally:
        hds.close()

    t = H[:, 0]
    h0 = H[:, 1]
    # Face flux k * dh / distance, positive when the neighbour head is higher
    # (flow towards the centre cell), summed over all six faces.
    q_total = sum(kh * (h - h0) / dist
                  for h, (_, dist) in zip(neighbour_heads, neighbours))
    Q = np.column_stack((t, q_total))

    return temperature, Q, H
    
  

